Paper reading May 2025
Self-attention is a key component of the Transformer architecture.
However, it scales quadratically with the sequence length, making it inefficient for long sequences.
There have been various attempts to improve the efficiency of attention mechanisms, making them linear or near-linear in the sequence length.
Flash attention [1], [2] shows that the limiting factor in the performance of attention is the memory bandwidth.
We can reduce computation time by making the implementation more I/O-aware.
Attention forward and backwards pass on A100 GPU. [2]
Flash attention relies on a technique called “tiling” to break down the attention computation into smaller chunks that fit into the GPU’s memory.
This allows for efficient computation without the need for large intermediate tensors.
However, for tiling to work we need the operations to be associative, which is not the case for the standard attention mechanism, since softmax is not associative.
The softmax function is a key component of the attention mechanism.
It is used to compute the attention weights, which scale the values based on the similarity of the queries and keys. \[ \text{Attention}(Q, K, V) = \text{Softmax}(\frac{QK^\mathsf{T}}{\sqrt{d_k}}) V, \text{where} \\ Q, K, V \in \mathbb{R}^{N \times d_k}, \\ N \text{ is the sequence length,}\\ d_k \text{ is the dimension of the heads.} \]
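As a reference point, here is a minimal NumPy sketch of the formula above (the shapes and variable names are illustrative, not from the papers):

```python
import numpy as np

def naive_attention(Q, K, V):
    """Naive attention: materializes the full N x N score matrix."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                         # (N, N)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)          # row-wise softmax
    return weights @ V                                      # (N, d_k)

rng = np.random.default_rng(0)
N, d_k = 8, 4
Q, K, V = (rng.standard_normal((N, d_k)) for _ in range(3))
O = naive_attention(Q, K, V)
```

The \(N \times N\) `scores` tensor is exactly the intermediate that Flash Attention avoids materializing.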
For a vector \(x = \{x_i\}_{i=1}^N\in \mathbb{R}^N\), the softmax function is defined as follows: \[ \text{Softmax}(x) = \left\{\frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}\right\}_{i=1}^N\]
If any of the \(x_i\) is very large, the exponentials will overflow and the softmax will return NaN.
The maximum representable float32 value is \(3.4028235 \times 10^{38}\), so the naive softmax overflows as soon as any \(x_i\) exceeds \(\ln(3.4028235 \times 10^{38}) \approx 88.722839\). The standard fix is to subtract the maximum \(m = \max_i x_i\) before exponentiating, which leaves the result unchanged.
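A quick demonstration of the overflow and of the max-subtraction fix (the values are chosen to straddle the float32 limit):

```python
import numpy as np

x = np.array([10.0, 89.0, 100.0], dtype=np.float32)

# Naive softmax: exp(89) and exp(100) exceed the float32 maximum (~3.4e38),
# so the sum becomes inf and inf / inf yields NaN.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(x) / np.exp(x).sum()

# Safe softmax: subtracting the max makes every exponent <= 0.
shifted = x - x.max()
safe = np.exp(shifted) / np.exp(shifted).sum()
```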
The naive implementation requires three passes over the data.
Let's introduce the following notation: \(m_i\) is the running maximum of the first \(i\) elements, \(d_i\) the normalizer \(\sum_{j=1}^i e^{x_j - m_N}\), and \(a_i\) the final softmax output.
For \(i \leftarrow 1, N\) do \[ m_{i} \leftarrow \max \left(m_{i-1}, x_{i}\right) \qquad(1)\]
For \(i \leftarrow 1, N\) do \[ d_{i} \leftarrow d_{i-1}+e^{x_{i}-m_{N}} \qquad(2)\]
For \(i \leftarrow 1, N\) do \[ a_{i} \leftarrow \frac{e^{x_{i}-m_{N}}}{d_{N}} \qquad(3)\]
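The three passes map directly onto code; a didactic sketch (not an optimized kernel):

```python
import numpy as np

def softmax_3pass(x):
    N = len(x)
    m = -np.inf
    for i in range(N):            # Pass 1: global maximum (Equation 1)
        m = max(m, x[i])
    d = 0.0
    for i in range(N):            # Pass 2: normalizer, needs m_N (Equation 2)
        d += np.exp(x[i] - m)
    return np.array([np.exp(x[i] - m) / d
                     for i in range(N)])    # Pass 3: outputs (Equation 3)
```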
We want to fuse the operations in a single loop, however (Equation 2) and (Equation 3) cannot be fused as (Equation 3) depends on the value of \(m_N\) which is not known until the end of the loop.
We can create a surrogate sequence \(d'_{i}\) that is computed in the same way as \(d_i\) but does not depend on \(m_N\): \[ d'_{i} = \sum_{j=1}^i e^{x_j-m_i}. \]
Furthermore \(d_{N} = d'_{N}\), so we can replace \(d_N\) with \(d'_N\) in (Equation 3).
For \(i \leftarrow 1, N\) do \[ m_{i} \leftarrow \max \left(m_{i-1}, x_{i}\right)\\ d'_i \leftarrow d'_{i-1} e^{m_{i-1}-m_i} + e^{x_i-m_i} \qquad(4)\]
For \(i \leftarrow 1, N\) do \[ a_{i} \leftarrow \frac{e^{x_{i}-m_{N}}}{d'_{N}} \qquad(5)\]
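With the surrogate \(d'_i\), the first two passes fuse into one loop; a sketch of Equations 4 and 5:

```python
import numpy as np

def softmax_online(x):
    m, d = -np.inf, 0.0
    for xi in x:                        # fused pass (Equation 4)
        m_new = max(m, xi)
        d = d * np.exp(m - m_new) + np.exp(xi - m_new)  # rescale the old sum
        m = m_new
    # Output pass (Equation 5): m and d now hold m_N and d'_N.
    return np.array([np.exp(xi - m) / d for xi in x])
```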
While we can’t reduce the number of passes for softmax below two, we can decrease the number of passes for the self-attention mechanism by finding a one-pass recurrence relation for the \(O\) matrix.
\(Q[k,:]\) : the \(k\)-th row vector of the \(Q\) matrix.
\(K^{T}[:, i]\) : the \(i\)-th column vector of \(K^{T}\) matrix.
\(O[k,:]\) : the \(k\)-th row of output \(O\) matrix.
\(V[i,:]\) : the \(i\)-th row of \(V\) matrix.
\(\left\{\boldsymbol{o}_{i}\right\}: \sum_{j=1}^{i} a_{j} V[j,:]\), a row vector storing partial aggregation result \(A[k,: i] \times V[: i,:]\)
For \(i \leftarrow 1, N\) do
\[\begin{aligned} x_{i} & \leftarrow Q[k,:] K^{T}[:, i] \\ m_{i} & \leftarrow \max \left(m_{i-1}, x_{i}\right) \\ d_{i}^{\prime} & \leftarrow d_{i-1}^{\prime} e^{m_{i-1}-m_{i}}+e^{x_{i}-m_{i}} \end{aligned}\]
For \(i \leftarrow 1, N\) do \[ a_{i} \leftarrow \frac{e^{x_{i}-m_{N}}}{d_{N}^{\prime}} \qquad(6)\] \[ \boldsymbol{o}_{i} \leftarrow \boldsymbol{o}_{i-1}+a_{i} V[i,:] \qquad(7)\]
\[ O[k,:] \leftarrow \boldsymbol{o}_{N} \]
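The two-loop row computation above can be sketched as follows (the \(1/\sqrt{d_k}\) scaling is omitted here, matching the derivation; `attention_row_2pass` is an illustrative name):

```python
import numpy as np

def attention_row_2pass(Q, K, V, k):
    """Computes O[k, :] with two passes over the keys (Equations 6-7)."""
    N = K.shape[0]
    x = np.empty(N)
    m, d = -np.inf, 0.0
    for i in range(N):                  # pass 1: scores x_i, running m_i, d'_i
        x[i] = Q[k, :] @ K[i, :]
        m_new = max(m, x[i])
        d = d * np.exp(m - m_new) + np.exp(x[i] - m_new)
        m = m_new
    o = np.zeros(V.shape[1])
    for i in range(N):                  # pass 2: a_i and the accumulation o_i
        o += np.exp(x[i] - m) / d * V[i, :]
    return o
```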
Let's unroll (Equation 7), substituting \(a_i\) from (Equation 6): \[ \boldsymbol{o}_{i}:=\sum_{j=1}^{i}\left(\frac{e^{x_{j}-m_{N}}}{d_{N}^{\prime}} V[j,:]\right) \]
This depends on \(m_N\) and \(d'_N\), which are only known after the full pass, so again we create a surrogate sequence \(\boldsymbol{o}'_{i}\): \[ \boldsymbol{o}_{i}^{\prime}:=\left(\sum_{j=1}^{i} \frac{e^{x_{j}-m_{i}}}{d_{i}^{\prime}} V[j,:]\right) \]
\[\begin{align*} \boldsymbol{o}_{i}^{\prime} & =\sum_{j=1}^{i} \frac{e^{x_{j}-m_{i}}}{d_{i}^{\prime}} V[j,:] \\ & =\left(\sum_{j=1}^{i-1} \frac{e^{x_{j}-m_{i}}}{d_{i}^{\prime}} V[j,:]\right)+\frac{e^{x_{i}-m_{i}}}{d_{i}^{\prime}} V[i,:] \\ & =\left(\sum_{j=1}^{i-1} \frac{e^{x_{j}-m_{i-1}}}{d_{i-1}^{\prime}} \frac{e^{x_{j}-m_{i}}}{e^{x_{j}-m_{i-1}}} \frac{d_{i-1}^{\prime}}{d_{i}^{\prime}} V[j,:]\right)+\frac{e^{x_{i}-m_{i}}}{d_{i}^{\prime}} V[i,:] \\ & =\left(\sum_{j=1}^{i-1} \frac{e^{x_{j}-m_{i-1}}}{d_{i-1}^{\prime}} V[j,:]\right) \frac{d_{i-1}^{\prime}}{d_{i}^{\prime}} e^{m_{i-1}-m_{i}}+\frac{e^{x_{i}-m_{i}}}{d_{i}^{\prime}} V[i,:] \\ \end{align*}\]
\[\begin{align*} \boldsymbol{o}_{i}^{\prime} & =\sum_{j=1}^{i} \frac{e^{x_{j}-m_{i}}}{d_{i}^{\prime}} V[j,:] \\ & =\boldsymbol{o}_{i-1}^{\prime} \frac{d_{i-1}^{\prime} e^{m_{i-1}-m_{i}}}{d_{i}^{\prime}}+\frac{e^{x_{i}-m_{i}}}{d_{i}^{\prime}} V[i,:] \end{align*}\]
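This recurrence gives a genuinely one-pass row computation, which is the core of the Flash Attention inner loop; a plain NumPy sketch (a real kernel processes tiles of keys and values, not single rows):

```python
import numpy as np

def attention_row_1pass(Q, K, V, k):
    """Computes O[k, :] in a single pass using the o'_i recurrence."""
    m, d = -np.inf, 0.0
    o = np.zeros(V.shape[1])
    for i in range(K.shape[0]):
        x = Q[k, :] @ K[i, :]
        m_new = max(m, x)
        d_new = d * np.exp(m - m_new) + np.exp(x - m_new)
        # o'_i = o'_{i-1} * (d'_{i-1} e^{m_{i-1}-m_i} / d'_i)
        #        + (e^{x_i-m_i} / d'_i) * V[i, :]
        o = o * (d * np.exp(m - m_new) / d_new) \
            + np.exp(x - m_new) / d_new * V[i, :]
        m, d = m_new, d_new
    return o
```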
Running code on CPU (Taken from CS149 Stanford)
Running code on GPU(Taken from CS149 Stanford)
CUDA thread hierarchy (Taken from CS149 Stanford)
CUDA memory model (Taken from CS149 Stanford)
CUDA types of memory (Taken from CS149 Stanford)
CUDA memory types (Taken from yt)
CUDA memory types [4]
CUDA synchronization (Taken from CS149 Stanford)
Thread memory access [4]
Breaking down the matrix into tiles [4]
Tiled execution order [4]
Tiled matmul kernel [4]
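A CPU-side sketch of the tiled execution order from the figures above (the tile size is an arbitrary choice here; a CUDA kernel would stage each tile in shared memory and synchronize between tiles):

```python
import numpy as np

def tiled_matmul(A, B, tile=4):
    """C = A @ B computed tile by tile: each output block accumulates
    partial products of small tiles streamed from A and B."""
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N))
    for i0 in range(0, M, tile):
        for j0 in range(0, N, tile):
            for k0 in range(0, K, tile):        # loop over tiles along K
                C[i0:i0 + tile, j0:j0 + tile] += (
                    A[i0:i0 + tile, k0:k0 + tile]
                    @ B[k0:k0 + tile, j0:j0 + tile]
                )
    return C
```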
Flash attention computation on GPU [5]
Another graphical representation of Flash Attention [2]
Flash attention algorithm [1]
Flash attention 2 algorithm [2]